# _*_ coding: utf-8 _*_
# @Date : 2023/3/20 23:46
# @Author : Paul
# @File : logistic_regression.py
# @Description : Logistic-regression classifier job built on scikit-learn's LogisticRegression
from classifiers.classifier import Classifier

from core.utils.log_util import LogUtil

from core.beans.classifier_result import ClassifierResult

from core.utils.date_util import DateUtil
from sklearn.linear_model import LogisticRegression
from core.utils.string_utils import StringUtils

import json
import sys


class TWLogisticRegression(Classifier):
    """
    Logistic-regression classifier built on scikit-learn's ``LogisticRegression``.

    Hyper-parameters are read from ``param["algoParam"]`` (see ``initModel``).
    Data loading, splitting and orchestration of initModel / buildModel /
    evalModel are handled by the ``Classifier`` base class, which is defined
    elsewhere.
    """

    def __init__(self,
                 app_name="clusters",
                 data_source_id=None,
                 table_name=None,
                 feature_cols=None,
                 class_col=None,
                 train_size=None,
                 param=None
                 ):
        """
        Initialise the classifier.

        :param app_name: application name; also embedded in ``tree_pred_image``
        :param data_source_id: id of the data source (forwarded to the base class)
        :param table_name: source table name (forwarded to the base class)
        :param feature_cols: feature column(s) (forwarded to the base class)
        :param class_col: label column(s) (forwarded to the base class)
        :param train_size: training-set ratio (forwarded to the base class)
        :param param: full job-parameter dict; ``param["algoParam"]`` holds the
                      model hyper-parameters and ``param["id"]`` the job id
        """
        super(TWLogisticRegression, self).__init__(app_name=app_name,
                                                   data_source_id=data_source_id,
                                                   table_name=table_name,
                                                   feature_cols=feature_cols,
                                                   class_col=class_col,
                                                   train_size=train_size,
                                                   param=param)
        # Flag presumably read by the Classifier base class; True appears to
        # enable the evalModel step after training -- confirm in the base class.
        self.IS_MODEL_EVAL = True
        # Prediction-image name carried over from the decision-tree classifier;
        # this class never renders an image itself. The attribute and its
        # string are kept unchanged in case the base class or tooling uses them.
        self.tree_pred_image = "descion_tree_pred_" + app_name + "_" + DateUtil.getCurrentDateSimple()

    @staticmethod
    def _param_or_default(algo_param, key, default, convert=None):
        """
        Read one hyper-parameter from ``algo_param``.

        Returns ``default`` when ``key`` is missing, ``None`` or blank
        (according to ``StringUtils.isBlack``); otherwise the raw value, passed
        through ``convert`` when one is given.
        """
        value = algo_param.get(key)
        if value is None or StringUtils.isBlack(value):
            return default
        return value if convert is None else convert(value)

    @staticmethod
    def _parse_class_weight(raw):
        """
        Convert the ``classWeight`` setting into a value sklearn accepts.

        sklearn's ``class_weight`` takes either a ``{label: weight}`` mapping or
        the string ``"balanced"``. A string other than ``"balanced"`` is decoded
        as a JSON object; any other value is converted with ``dict()``.
        """
        if isinstance(raw, str):
            text = raw.strip()
            if text == "balanced":
                return text
            # NOTE(review): JSON object keys are always strings; if the label
            # column holds integers the keys may need converting -- confirm.
            return dict(json.loads(text))
        return dict(raw)

    def initModel(self):
        """
        Build the (unfitted) sklearn estimator from ``self.param["algoParam"]``.

        Recognised keys; a missing or blank value falls back to the default:
        ``penalty`` ("l2"), ``paramC`` (1.0), ``maxIter`` (100),
        ``multiClass`` (not passed, so sklearn's own default applies),
        ``solver`` ("lbfgs"), ``classWeight`` (None; a mapping, a JSON object
        string or "balanced"), ``nJobs`` (None).
        """
        algo_param = self.param["algoParam"]
        read = self._param_or_default

        model_kwargs = {
            "penalty": read(algo_param, "penalty", "l2"),
            "C": read(algo_param, "paramC", 1.0, float),
            "max_iter": read(algo_param, "maxIter", 100, int),
            "solver": read(algo_param, "solver", "lbfgs", str),
            "class_weight": read(algo_param, "classWeight", None, self._parse_class_weight),
            "n_jobs": read(algo_param, "nJobs", None, int),
        }
        # Only forward multi_class when configured: sklearn's default is already
        # 'auto', and the parameter is deprecated since scikit-learn 1.5.
        multi_class = read(algo_param, "multiClass", None, str)
        if multi_class is not None:
            model_kwargs["multi_class"] = multi_class

        self.clf = LogisticRegression(**model_kwargs)

    def buildModel(self, train_data):
        """
        Fit the estimator created by ``initModel``.

        :param train_data: indexable pair; item 0 is X_train, item 1 is y_train
        """
        x_train, y_train = train_data[0], train_data[1]
        self.clf = self.clf.fit(x_train, y_train)

    def evalModel(self, train_data, test_data):
        """
        Score the fitted model on the test split and persist the result.

        :param train_data: not used here; kept for the base-class interface
        :param test_data: indexable pair; item 0 is X_test, item 1 is y_test
        """
        x_test, y_test = test_data[0], test_data[1]
        score_ = self.clf.score(x_test, y_test)

        # Record the first row of coef_ as per-feature "importance". Binary
        # problems have a single row; for multiclass only the first class's
        # weights are recorded. Assumes self.feature_cols lists the column
        # names in training-matrix order (set up by the base class) -- confirm.
        coefficients = self.clf.coef_[0]
        var_importance = {col: coefficients[i] for i, col in enumerate(self.feature_cols)}

        # End time of the run
        end_time = DateUtil.getCurrentDate()
        # NOTE(review): "diffMin" suggests minutes although the variable says seconds -- confirm.
        cost_second = DateUtil.diffMin(self.start_time, end_time)

        # Save the model result (intended to be stored in MySQL)
        algo_result = ClassifierResult(self.param["id"],
                                       "logistic_regression",
                                       self.param,
                                       self.app_name,
                                       self.info,
                                       self.describe,
                                       None,
                                       var_importance,
                                       score_,
                                       # status spelling kept as-is: consumers may match on it
                                       "sucess",
                                       self.start_time,
                                       end_time,
                                       cost_second)
        LogUtil.saveClassifierResult(self.meta_data_source, algo_result)

def main(argv=None):
    """
    Command-line entry point: run a logistic-regression job from JSON parameters.

    :param argv: argument list without the program name; defaults to
                 ``sys.argv[1:]``. The first argument must be a JSON object with
                 keys ``appName``, ``dataSourceId``, ``tableName``,
                 ``featureCols``, ``classCols``, ``trainDataRatio``, ``id`` and
                 ``algoParam``. If ``paramTrain`` is present, ``paramTrain()``
                 runs instead of ``execute()``.

    Exits with a message (non-zero status) when the argument is missing or is
    not a JSON object.
    """
    # Example argument:
    # {"algoParam":{"classWeight":"","maxIter":"","multiClass":"auto","nJobs":"","paramC":"","penalty":"l2","solver":"lbfgs"},"appName":"kmeans_1","classCols":"target","dataSourceId":"9","featureCols":"mean_radius,mean_texture,mean_perimeter,mean_area,mean_smoothness","id":"1679627149440","preProcessMethodList":[{"preProcessMethod":"deletena"}],"tableName":"breast_cancer","trainDataRatio":"0.7"}
    args = sys.argv[1:] if argv is None else argv
    if not args:
        sys.exit("usage: logistic_regression.py '<JSON job parameters>'")

    try:
        param = json.loads(args[0])
    except ValueError as err:  # json.JSONDecodeError subclasses ValueError
        sys.exit("invalid job parameters, expected a JSON object: %s" % err)
    if not isinstance(param, dict):
        sys.exit("invalid job parameters, expected a JSON object")

    # The classifier expects a list of label columns; wrap a single name.
    class_cols = param["classCols"]
    class_cols_list = class_cols if isinstance(class_cols, list) else [class_cols]

    classifier = TWLogisticRegression(app_name=param["appName"],
                                      data_source_id=param["dataSourceId"],
                                      table_name=param["tableName"],
                                      feature_cols=param["featureCols"],
                                      class_col=class_cols_list,
                                      train_size=float(param["trainDataRatio"]),
                                      param=param)
    if "paramTrain" in param:
        classifier.paramTrain()
    else:
        classifier.execute()


if __name__ == '__main__':
    main()